Skip to content

Hikettei/polyhedral schedule#13

Merged
hikettei merged 9 commits intomainfrom
hikettei/polyhedral-schedule
Nov 21, 2025
Merged

Hikettei/polyhedral schedule#13
hikettei merged 9 commits intomainfrom
hikettei/polyhedral-schedule

Conversation

@hikettei
Copy link
Owner

No description provided.

@gemini-code-assist
Copy link

Summary of Changes

Hello @hikettei, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the polyhedral scheduling DSL by introducing a dedicated PolyhedralSchedule class and a stmt function for automatic access relation inference. These additions streamline the process of defining, combining, and transforming polyhedral schedules, making the API more intuitive and robust for users. The changes are demonstrated through refactored examples, showcasing a more integrated workflow for polyhedral compilation tasks.

Highlights

  • New PolyhedralSchedule Class: Introduced a new PolyhedralSchedule class to encapsulate an ISL schedule, along with its associated read and write UnionMaps, and provide methods for dependency analysis and schedule manipulation.
  • Automatic Access Relation Inference: Added a stmt function that allows for automatic inference of read and write access relations from a simple string representation of a statement (e.g., Out[n,k,h,w] = In[n,c,h,w]), simplifying the DSL usage.
  • Enhanced P.domain.finalize: The P.domain.finalize method now returns an instance of the new PolyhedralSchedule class, making it easier to manage and transform schedules with their associated access information.
  • Refactored Examples: The conv2d_pool2d_fusion.py example has been refactored to utilize the new PolyhedralSchedule API and P.stmt, demonstrating a cleaner and more integrated approach to polyhedral schedule construction and transformation. A new conv2d_pool2d_fusion_v2.py example was also added.
  • Polyhedral DSL Guidelines: Updated AGENTS.md with new guidelines for the Polyhedral DSL, specifically recommending the use of Mixin operator overloads for cleaner code.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@hikettei
Copy link
Owner Author

/gemini review

@hikettei hikettei merged commit b9f7f77 into main Nov 21, 2025
4 checks passed
@hikettei hikettei deleted the hikettei/polyhedral-schedule branch November 21, 2025 10:10
Copy link

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a PolyhedralSchedule class and a stmt function to create a more user-friendly DSL for polyhedral scheduling, which is a significant improvement. The refactoring of the conv2d_pool2d_fusion example effectively showcases the benefits of this new API. My review includes a critical fix for a missing function in a new example, along with several medium-severity suggestions to enhance type safety, improve code conciseness by refactoring repetitive logic, and align the new code with the newly introduced DSL guidelines by implementing operator overloading.

hikettei added a commit that referenced this pull request Nov 21, 2025
* fix: Final formatting fix

* test: Accept ruff formatting for test_fusion.py

* test: Accept ruff formatting for test_fusion.py (final)

* fix: Format domain.py to satisfy ruff

* update

* fix: Final formatting fix for example

* fix: Final format fix accept

* update

* update
Copy link

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a PolyhedralSchedule class and a stmt DSL function, which significantly simplify the process of creating and manipulating polyhedral schedules. The refactoring of the conv2d_pool2d_fusion.py example demonstrates the improved readability and usability of the new API. Overall, this is a valuable addition.

However, I've identified a few issues that should be addressed. There's a critical correctness issue in the is_legal method, which could lead to incorrect schedule validation. Additionally, the test coverage threshold has been lowered, and the new modules lack dedicated unit tests, which is a concern for maintainability. I've also noted several opportunities for refactoring to reduce code duplication and to adhere to the newly introduced coding guidelines. Please see my detailed comments below.

Comment on lines +26 to +30
def is_legal(self) -> bool:
# Check legality against RAW dependencies
if self.raw_dep:
return schedule_is_legal_p(self.isl_schedule, self.raw_dep)
return True

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The is_legal method currently checks for schedule legality only against Read-After-Write (RAW) dependencies. For a schedule to be truly legal, it must respect all data dependencies, including Write-After-Write (WAW) and Write-After-Read (WAR). The compute_dependencies method already calculates self.total_dep, which is the union of all dependencies. You should use self.total_dep for a correct legality check.

Suggested change
def is_legal(self) -> bool:
# Check legality against RAW dependencies
if self.raw_dep:
return schedule_is_legal_p(self.isl_schedule, self.raw_dep)
return True
def is_legal(self) -> bool:
# Check legality against all dependencies
if self.total_dep:
return schedule_is_legal_p(self.isl_schedule, self.total_dep)
return True


[tool.pytest.ini_options]
addopts = "--cov=caten --cov-report=term-missing --cov-fail-under=70"
addopts = "--cov=caten --cov-report=term-missing --cov-fail-under=60"

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Lowering the test coverage threshold from 70% to 60% is a significant step backward for the project's code quality and maintainability. This change suggests that the new functionality is not adequately tested. Instead of reducing the quality bar, please add the necessary tests to meet the original coverage requirement. This ensures that new code is as reliable as the existing codebase.

Comment on lines 1 to +99
import caten.isl as I
import caten.polyhedral as P
from caten.polyhedral.analysis import compute_dependence_relation, schedule_is_legal_p
from caten.polyhedral.transformations import schedule_node_sequence_full_fuse


def create_conv_schedule(N, K_out, H_out, W_out, Cin, KH, KW):
dom_str = f"{{ S_conv[n, k, h, w, c, kh, kw] : 0<=n<{N} and 0<=k<{K_out} and 0<=h<{H_out} and 0<=w<{W_out} and 0<=c<{Cin} and 0<=kh<{KH} and 0<=kw<{KW} }}"

with P.domain(dom_str) as conv:
with P.band("{ S_conv[n, k, h, w, c, kh, kw] -> [n, k, h, w, c, kh, kw] }"):
P.stmt("Out[n, k, h, w] = Out[n, k, h, w], In[n, c, h, w], W[k, c, kh, kw]")

return conv.finalize()

def create_pool_schedule(N, K_out, H_pool, W_pool, S_pool, KH_pool, KW_pool):
dom_str = f"{{ S_pool[n, k, h, w, rh, rw] : 0<=n<{N} and 0<=k<{K_out} and 0<=h<{H_pool} and 0<=w<{W_pool} and 0<=rh<{KH_pool} and 0<=rw<{KW_pool} }}"

with P.domain(dom_str) as pool:
with P.band("{ S_pool[n, k, h, w, rh, rw] -> [n, k, h, w, rh, rw] }"):
P.stmt(f"PoolBuf[n, k, h, w] = PoolBuf[n, k, h, w], Out[n, k, h*{S_pool} + rh, w*{S_pool} + rw]")

return pool.finalize()

def test_conv2d_pool2d_fusion():
# Parameters
N = 10
Cin = 16
Cout = 32
H_in, W_in = 32, 32

KH_conv, KW_conv = 3, 3
S_conv = 1
H_conv = (H_in - KH_conv) // S_conv + 1
W_conv = (W_in - KW_conv) // S_conv + 1

KH_pool, KW_pool = 2, 2
S_pool = 2
H_pool = (H_conv - KH_pool) // S_pool + 1
W_pool = (W_conv - KW_pool) // S_pool + 1

Tile_H = S_pool
Tile_W = S_pool

# Domains
conv_dom_str = f"{{ S_conv[n, k, h, w, c, kh, kw] : 0<=n<{N} and 0<=k<{Cout} and 0<=h<{H_conv} and 0<=w<{W_conv} and 0<=c<{Cin} and 0<=kh<{KH_conv} and 0<=kw<{KW_conv} }}"
pool_dom_str = f"{{ S_pool[n, k, h, w, rh, rw] : 0<=n<{N} and 0<=k<{Cout} and 0<=h<{H_pool} and 0<=w<{W_pool} and 0<=rh<{KH_pool} and 0<=rw<{KW_pool} }}"

with I.context():
# Access Relations (including Reduction Dependencies)
writes_conv = I.UnionMap(
f"{{ S_conv[n, k, h, w, c, kh, kw] -> Out[n, k, h, w] : "
f"0<=n<{N} and 0<=k<{Cout} and 0<=h<{H_conv} and 0<=w<{W_conv} }}"
)
reads_conv_acc = I.UnionMap(
f"{{ S_conv[n, k, h, w, c, kh, kw] -> Out[n, k, h, w] : "
f"0<=n<{N} and 0<=k<{Cout} and 0<=h<{H_conv} and 0<=w<{W_conv} }}"
)

reads_pool = I.UnionMap(
f"{{ S_pool[n, k, h, w, rh, rw] -> Out[n, k, h*{S_pool} + rh, w*{S_pool} + rw] : "
f"0<=n<{N} and 0<=k<{Cout} and 0<=h<{H_pool} and 0<=w<{W_pool} and 0<=rh<{KH_pool} and 0<=rw<{KW_pool} }}"
)
writes_pool = I.UnionMap(
"{ S_pool[n, k, h, w, rh, rw] -> PoolBuf[n, k, h, w] }"
)
reads_pool_acc = I.UnionMap(
"{ S_pool[n, k, h, w, rh, rw] -> PoolBuf[n, k, h, w] }"
)

all_writes = writes_conv.union(writes_pool)
all_reads = reads_pool.union(reads_conv_acc).union(reads_pool_acc)

# Initial Schedule
filters = I.UnionSetList.alloc(2)
filters = filters.add(I.UnionSet(conv_dom_str))
filters = filters.add(I.UnionSet(pool_dom_str))

full_dom = I.UnionSet(conv_dom_str).union(I.UnionSet(pool_dom_str))
sched = I.Schedule.from_domain(full_dom)

root = sched.get_root()
seq_node = root.child(0).insert_sequence(filters)

conv_filter = seq_node.child(0)
conv_mupa = I.MultiUnionPwAff.from_union_map(I.UnionMap("{ S_conv[n, k, h, w, c, kh, kw] -> [n, k, h, w, c, kh, kw] }"))
conv_node = conv_filter.child(0).insert_partial_schedule(conv_mupa)

seq_node = conv_node.parent().parent()
pool_filter = seq_node.child(1)
pool_mupa = I.MultiUnionPwAff.from_union_map(I.UnionMap("{ S_pool[n, k, h, w, rh, rw] -> [n, k, h, w, rh, rw] }"))
pool_node = pool_filter.child(0).insert_partial_schedule(pool_mupa)

initial_sched = pool_node.get_schedule()

total_dep, raw, waw, war = compute_dependence_relation(
read=all_reads,
write=all_writes,
schedule=initial_sched
)

legal = schedule_is_legal_p(initial_sched, total_dep)
assert legal

# --- Transformations ---
def get_seq_from_band(band):
return band.parent().parent()

# Split Bands
seq_node = get_seq_from_band(pool_node)
conv_band = seq_node.child(0).child(0)
conv_band = conv_band.band_split(2)
conv_band_2 = conv_band.child(0)
conv_band_2 = conv_band_2.band_split(2)

# Use correct navigation for nested bands
seq_node = conv_band_2.parent().parent().parent()

pool_band = seq_node.child(1).child(0)
pool_band = pool_band.band_split(2)
pool_band_2 = pool_band.child(0)
pool_band_2 = pool_band_2.band_split(2)

seq_node = pool_band_2.parent().parent().parent()

# Fuse NK
nk_band = schedule_node_sequence_full_fuse(seq_node)

# Tile Conv HW
seq_node = nk_band.child(0)
conv_hw = seq_node.child(0).child(0)
conv_hw = conv_hw.band_set_permutable(1)
space = conv_hw.band_get_space()

mv = I.MultiVal.zero(space)
mv = mv.set_val(0, I.Val.int_from_si(Tile_H))
mv = mv.set_val(1, I.Val.int_from_si(Tile_W))

conv_tiled = conv_hw.band_tile(mv)
conv_tiled = conv_tiled.band_scale_down(mv)

# Replace Inner Band
inner = conv_tiled.child(0)
replaced = inner.delete()
new_mupa = I.MultiUnionPwAff.from_union_map(I.UnionMap(f"{{ S_conv[n, k, h, w, c, kh, kw] -> [(h%{Tile_H}), (w%{Tile_W})] }}"))
conv_tiled_inner = replaced.insert_partial_schedule(new_mupa) # noqa: F841

seq_node = conv_tiled_inner.parent().parent().parent()

# Fuse HW Tiles
hw_tile_band = schedule_node_sequence_full_fuse(seq_node)

# Fuse Inner
seq_node = hw_tile_band.child(0)
inner_band = schedule_node_sequence_full_fuse(seq_node)

final_sched = inner_band.get_schedule()

legal_fused = schedule_is_legal_p(final_sched, total_dep)
assert legal_fused

c_code = P.to_c(final_sched)
assert "S_conv" in c_code
assert "S_pool" in c_code
conv = create_conv_schedule(N, Cout, H_conv, W_conv, Cin, KH_conv, KW_conv)
pool = create_pool_schedule(N, Cout, H_pool, W_pool, S_pool, KH_pool, KW_pool)

psched = conv.sequence(pool)

root = psched.get_root()
seq_node = root.child(0)

# 1. Split Bands
conv_band = seq_node.child(0).child(0)
conv_band = conv_band.band_split(2) # [n, k]
conv_band_2 = conv_band.child(0)
conv_band_2 = conv_band_2.band_split(2) # [h, w]

seq_node = conv_band_2.parent().parent().parent()

pool_band = seq_node.child(1).child(0)
pool_band = pool_band.band_split(2) # [n, k]
pool_band_2 = pool_band.child(0)
pool_band_2 = pool_band_2.band_split(2) # [h, w]

seq_node = pool_band_2.parent().parent().parent()

# 2. Fuse NK
nk_band = schedule_node_sequence_full_fuse(seq_node)
psched.update(nk_band)

# 3. Tile Conv HW
seq_node = nk_band.child(0)
conv_hw = seq_node.child(0).child(0)
conv_hw = conv_hw.band_set_permutable(1)
space = conv_hw.band_get_space()
mv = I.MultiVal.zero(space)
mv = mv.set_val(0, I.Val.int_from_si(Tile_H))
mv = mv.set_val(1, I.Val.int_from_si(Tile_W))

conv_tiled = conv_hw.band_tile(mv)
conv_tiled = conv_tiled.band_scale_down(mv)

# Replace Inner Band with Relative
inner = conv_tiled.child(0)
replaced = inner.delete()
new_mupa = I.MultiUnionPwAff.from_union_map(I.UnionMap(f"{{ S_conv[n, k, h, w, c, kh, kw] -> [(h%{Tile_H}), (w%{Tile_W})] }}"))
conv_tiled_inner = replaced.insert_partial_schedule(new_mupa) # noqa: F841

seq_node = conv_tiled_inner.parent().parent().parent()

# 4. Fuse HW Tiles
hw_tile_band = schedule_node_sequence_full_fuse(seq_node)
psched.update(hw_tile_band)

# 5. Fuse Inner
seq_node = hw_tile_band.child(0)
inner_band = schedule_node_sequence_full_fuse(seq_node)
psched.update(inner_band)

assert psched.is_legal()
c_code = psched.to_c()
assert "S_conv" in c_code
assert "S_pool" in c_code

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This pull request introduces significant new functionality in caten/polyhedral/poly_schedule.py and caten/polyhedral/stmt.py, but it lacks dedicated unit tests for these new modules. The refactored integration test is good for verifying the end-to-end workflow, but it doesn't cover edge cases, error handling, or individual methods of the new classes and functions. The lack of unit tests is likely the reason the coverage threshold was lowered in pyproject.toml. Please add comprehensive unit tests for the new modules to ensure their correctness and robustness.

def __init__(self) -> None:
self.current_node: Optional["I.ScheduleNode"] = None
self.schedule: Optional["I.Schedule"] = None
self.current_domain: Any = None

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The type hint for current_domain is Any, which is not very specific. To improve type safety and code clarity, consider using a more specific type. A forward reference Optional["domain"] would be more appropriate here, assuming it refers to the domain class from caten.polyhedral.schedule_tree.domain. You will need to add the import under a TYPE_CHECKING block to avoid circular dependencies.

Suggested change
self.current_domain: Any = None
self.current_domain: Optional["domain"] = None

Comment on lines +49 to +63
new_reads = None
if self.reads and other.reads:
new_reads = self.reads.union(other.reads)
elif self.reads:
new_reads = self.reads
elif other.reads:
new_reads = other.reads

new_writes = None
if self.writes and other.writes:
new_writes = self.writes.union(other.writes)
elif self.writes:
new_writes = self.writes
elif other.writes:
new_writes = other.writes

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The logic for combining reads and writes maps is verbose and can be simplified. Additionally, the newly added guideline in AGENTS.md recommends using operator overloads like | instead of the .union() method for better readability. This can be refactored to be more concise and adhere to the new guideline.

        new_reads = self.reads | other.reads if self.reads and other.reads else self.reads or other.reads
        new_writes = self.writes | other.writes if self.writes and other.writes else self.writes or other.writes

builder.current_domain = self._prev_domain

def finalize(self, read: Optional[Union[str, "I.UnionMap"]] = None, write: Optional[Union[str, "I.UnionMap"]] = None) -> Any:
from ..poly_schedule import PolyhedralSchedule

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The import of PolyhedralSchedule is performed inside the finalize method. While this can be a strategy to avoid circular imports, it doesn't seem necessary here as poly_schedule.py does not import domain.py. Moving this import to the top of the file would follow standard Python conventions and improve readability.

Comment on lines +176 to +182
if self.domain_set:
uset = self.domain_set
if isinstance(uset, str):
uset = I.UnionSet(uset)
elif isinstance(uset, I.Set):
uset = I.UnionSet.from_set(uset)
self.schedule = I.Schedule.from_domain(uset)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This block of code, which converts a string or I.Set into an I.UnionSet, is duplicated in caten/polyhedral/stmt.py (lines 26-29). To adhere to the DRY (Don't Repeat Yourself) principle, this logic should be extracted into a shared helper function or a method.

Comment on lines +36 to +44
s_str = str(s)
if ":" in s_str:
tuple_part = s_str.split(":")[0].strip()
if tuple_part.startswith("{"):
tuple_part = tuple_part[1:].strip()
else:
tuple_part = s_str.strip()
if tuple_part.startswith("{") and tuple_part.endswith("}"):
tuple_part = tuple_part[1:-1].strip()

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Extracting the tuple part by parsing the string representation of an isl.Set is fragile. This implementation relies on the specific format of isl.Set.__str__, which could change in future versions of the underlying library, breaking your code. If the caten.isl API provides a more robust way to achieve this (e.g., by inspecting the set's space and dimensions programmatically), it would be better to use that. If not, it would be good to add a comment here highlighting the fragility of this approach.

Comment on lines +67 to +77
if new_reads:
if dom.reads_map:
dom.reads_map = dom.reads_map.union(new_reads)
else:
dom.reads_map = new_reads

if new_writes:
if dom.writes_map:
dom.writes_map = dom.writes_map.union(new_writes)
else:
dom.writes_map = new_writes No newline at end of file

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The logic for updating dom.reads_map and dom.writes_map is duplicated. This can be refactored into a small helper function to reduce repetition. Also, according to the new guideline in AGENTS.md, you should prefer the | operator over the .union() method.

Suggested change
if new_reads:
if dom.reads_map:
dom.reads_map = dom.reads_map.union(new_reads)
else:
dom.reads_map = new_reads
if new_writes:
if dom.writes_map:
dom.writes_map = dom.writes_map.union(new_writes)
else:
dom.writes_map = new_writes
if new_reads:
dom.reads_map = dom.reads_map | new_reads if dom.reads_map else new_reads
if new_writes:
dom.writes_map = dom.writes_map | new_writes if dom.writes_map else new_writes

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant